package encryption

import (
	"bytes"
	"encoding/base64"
	"fmt"

	"github.com/tink-crypto/tink-go/aead"
	"github.com/tink-crypto/tink-go/insecurecleartextkeyset"
	"github.com/tink-crypto/tink-go/jwt"
	"github.com/tink-crypto/tink-go/keyset"
	"github.com/tink-crypto/tink-go/tink"
)

// localEncryptionService encrypts data with an envelope AEAD whose data keys are
// wrapped by a locally held master keyset, and carries the JWT keyset handles
// that were decrypted with that same master key.
type localEncryptionService struct {
	// key is the envelope AEAD used by Encrypt/Decrypt and their string variants.
	key *aead.KMSEnvelopeAEAD

	// privateEc256Handle and publicEc256Handle are the JWT signing keyset and
	// its public counterpart, in that order.
	privateEc256Handle, publicEc256Handle *keyset.Handle
}

// NewLocalEncryption creates a new local encryption service.
//
// masterKey is the cleartext master keyset, and privateEc256 / publicEc256 are
// the JWT keysets encrypted with that master key. All three are unpadded
// base64-encoded JSON keysets, the format returned by GenerateLocalKeys; they can
// be generated by calling hatchet-admin keyset create-local.
//
// Data passed to the returned service is encrypted with an envelope AEAD: data
// keys from the AES128-GCM template are wrapped by the master key AEAD.
func NewLocalEncryption(masterKey []byte, privateEc256 []byte, publicEc256 []byte) (*localEncryptionService, error) {
	// The master keyset is stored in cleartext, so it is read without a wrapping key.
	masterHandle, err := insecureHandleFromBytes(masterKey)
	if err != nil {
		return nil, fmt.Errorf("reading master keyset: %w", err)
	}

	masterAEAD, err := aead.New(masterHandle)
	if err != nil {
		return nil, fmt.Errorf("creating master key AEAD: %w", err)
	}

	// Both JWT keysets were encrypted with the master key, so decrypt them with it.
	privateEc256Handle, err := handleFromBytes(privateEc256, masterAEAD)
	if err != nil {
		return nil, fmt.Errorf("reading private JWT keyset: %w", err)
	}

	publicEc256Handle, err := handleFromBytes(publicEc256, masterAEAD)
	if err != nil {
		return nil, fmt.Errorf("reading public JWT keyset: %w", err)
	}

	// NewKMSEnvelopeAEAD2 signals failure (e.g. an unsupported DEK template) with
	// a nil result rather than an error.
	envelope := aead.NewKMSEnvelopeAEAD2(aead.AES128GCMKeyTemplate(), masterAEAD)
	if envelope == nil {
		return nil, fmt.Errorf("failed to create envelope")
	}

	return &localEncryptionService{
		key:                envelope,
		privateEc256Handle: privateEc256Handle,
		publicEc256Handle:  publicEc256Handle,
	}, nil
}

// GenerateLocalKeys creates a fresh local master keyset together with a JWT
// keypair sealed by it.
//
// masterKey is the cleartext master keyset; privateEc256 and publicEc256 are the
// JWT private and public keysets encrypted with that master key. All three are
// base64-encoded JSON keysets. On error every byte slice is nil.
func GenerateLocalKeys() (masterKey []byte, privateEc256 []byte, publicEc256 []byte, err error) {
	rawMaster, handle, genErr := generateLocalMasterKey()
	if genErr != nil {
		return nil, nil, nil, genErr
	}

	sealer, aeadErr := aead.New(handle)
	if aeadErr != nil {
		return nil, nil, nil, aeadErr
	}

	priv, pub, jwtErr := generateJWTKeysets(sealer)
	if jwtErr != nil {
		return nil, nil, nil, jwtErr
	}

	return rawMaster, priv, pub, nil
}

// generateLocalMasterKey creates a new AES256-GCM keyset and returns both its
// cleartext serialization (unpadded base64 of the JSON keyset) and its handle.
func generateLocalMasterKey() ([]byte, *keyset.Handle, error) {
	handle, err := keyset.NewHandle(aead.AES256GCMKeyTemplate())
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create new keyset handle with AES256GCM template: %w", err)
	}

	// Named "serialized" rather than "bytes" to avoid shadowing the bytes package.
	serialized, err := insecureBytesFromHandle(handle)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get bytes from handle: %w", err)
	}

	return serialized, handle, nil
}

// generateJWTKeysets creates the keysets for JWT signing and verification encrypted with the
// masterKey. The masterKey can be from a remote KMS service or a local keyset.
//
// A new ES256 private keyset is generated and its public keyset derived from it;
// both are returned as unpadded base64 of the encrypted JSON keyset (see
// bytesFromHandle). On error both results are nil, never partially populated.
func generateJWTKeysets(masterKey tink.AEAD) (privateEc256 []byte, publicEc256 []byte, err error) {
	privateHandle, err := keyset.NewHandle(jwt.ES256Template())
	if err != nil {
		return nil, nil, fmt.Errorf("failed to create new keyset handle with ES256 template: %w", err)
	}

	privateBytes, err := bytesFromHandle(privateHandle, masterKey)
	if err != nil {
		return nil, nil, fmt.Errorf("serializing private JWT keyset: %w", err)
	}

	publicHandle, err := privateHandle.Public()
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get public keyset: %w", err)
	}

	publicBytes, err := bytesFromHandle(publicHandle, masterKey)
	if err != nil {
		return nil, nil, fmt.Errorf("serializing public JWT keyset: %w", err)
	}

	return privateBytes, publicBytes, nil
}

// bytesFromHandle serializes kh as a JSON keyset encrypted with masterKey (via
// keyset.Handle.Write) and returns that JSON encoded as unpadded standard base64.
// handleFromBytes performs the reverse transformation.
func bytesFromHandle(kh *keyset.Handle, masterKey tink.AEAD) ([]byte, error) {
	var jsonBuf bytes.Buffer
	if err := kh.Write(keyset.NewJSONWriter(&jsonBuf), masterKey); err != nil {
		return nil, fmt.Errorf("failed to write keyset: %w", err)
	}

	raw := jsonBuf.Bytes()
	encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(raw)))
	base64.RawStdEncoding.Encode(encoded, raw)
	return encoded, nil
}

// insecureBytesFromHandle serializes kh as a cleartext JSON keyset (no
// encryption is applied) and returns it encoded as unpadded standard base64.
// The result contains raw key material. insecureHandleFromBytes reverses it.
func insecureBytesFromHandle(kh *keyset.Handle) ([]byte, error) {
	var jsonBuf bytes.Buffer
	if err := insecurecleartextkeyset.Write(kh, keyset.NewJSONWriter(&jsonBuf)); err != nil {
		return nil, fmt.Errorf("failed to write keyset: %w", err)
	}

	raw := jsonBuf.Bytes()
	encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(raw)))
	base64.RawStdEncoding.Encode(encoded, raw)
	return encoded, nil
}

// handleFromBytes is the inverse of bytesFromHandle. It decodes keysetBytes from
// unpadded standard base64 into JSON, then reads the encrypted keyset from that
// JSON, decrypting it with masterKey.
func handleFromBytes(keysetBytes []byte, masterKey tink.AEAD) (*keyset.Handle, error) {
	keysetJSON := make([]byte, base64.RawStdEncoding.DecodedLen(len(keysetBytes)))

	// Decode ignores '\r' and '\n', so it may write fewer than DecodedLen bytes
	// (for example when the input ends with a newline). Only the first n bytes
	// are valid; the zero-filled tail would otherwise corrupt the JSON.
	n, err := base64.RawStdEncoding.Decode(keysetJSON, keysetBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to decode keyset bytes: %w", err)
	}

	handle, err := keyset.Read(keyset.NewJSONReader(bytes.NewReader(keysetJSON[:n])), masterKey)
	if err != nil {
		return nil, fmt.Errorf("failed to read keyset: %w", err)
	}

	return handle, nil
}

// insecureHandleFromBytes is the inverse of insecureBytesFromHandle. It decodes
// keysetBytes from unpadded standard base64 into JSON and reads it as a
// cleartext (unencrypted) keyset.
func insecureHandleFromBytes(keysetBytes []byte) (*keyset.Handle, error) {
	keysetJSON := make([]byte, base64.RawStdEncoding.DecodedLen(len(keysetBytes)))

	// Decode ignores '\r' and '\n', so it may write fewer than DecodedLen bytes
	// (for example when the input ends with a newline). Only the first n bytes
	// are valid; the zero-filled tail would otherwise corrupt the JSON.
	n, err := base64.RawStdEncoding.Decode(keysetJSON, keysetBytes)
	if err != nil {
		return nil, fmt.Errorf("failed to decode keyset bytes: %w", err)
	}

	handle, err := insecurecleartextkeyset.Read(keyset.NewJSONReader(bytes.NewReader(keysetJSON[:n])))
	if err != nil {
		return nil, fmt.Errorf("failed to read keyset: %w", err)
	}

	return handle, nil
}

// Encrypt seals plaintext with the service's envelope AEAD by delegating to
// encrypt. dataId is forwarded unchanged (presumably bound as associated data —
// confirm against encrypt).
func (svc *localEncryptionService) Encrypt(plaintext []byte, dataId string) ([]byte, error) {
	envelope := svc.key
	return encrypt(envelope, plaintext, dataId)
}

// Decrypt opens ciphertext with the service's envelope AEAD by delegating to
// decrypt. dataId must match the value used when the data was encrypted
// (presumably associated data — confirm against decrypt).
func (svc *localEncryptionService) Decrypt(ciphertext []byte, dataId string) ([]byte, error) {
	envelope := svc.key
	return decrypt(envelope, ciphertext, dataId)
}

// EncryptString encrypts the UTF-8 bytes of data the same way Encrypt does and
// returns the ciphertext as padded standard base64 text.
func (svc *localEncryptionService) EncryptString(data string, dataId string) (string, error) {
	ciphertext, err := svc.Encrypt([]byte(data), dataId)
	if err != nil {
		return "", err
	}

	return base64.StdEncoding.EncodeToString(ciphertext), nil
}

// DecryptString reverses EncryptString. It decodes data from padded standard
// base64, decrypts it the same way Decrypt does, and returns the plaintext as a
// string.
func (svc *localEncryptionService) DecryptString(data string, dataId string) (string, error) {
	ciphertext, err := base64.StdEncoding.DecodeString(data)
	if err != nil {
		return "", err
	}

	plaintext, err := svc.Decrypt(ciphertext, dataId)
	if err != nil {
		return "", err
	}

	return string(plaintext), nil
}

// GetPrivateJWTHandle returns the private JWT keyset handle that
// NewLocalEncryption decrypted with the master key. The handle itself is
// returned, not a copy.
func (svc *localEncryptionService) GetPrivateJWTHandle() *keyset.Handle {
	handle := svc.privateEc256Handle
	return handle
}

// GetPublicJWTHandle returns the public JWT keyset handle that
// NewLocalEncryption decrypted with the master key. The handle itself is
// returned, not a copy.
func (svc *localEncryptionService) GetPublicJWTHandle() *keyset.Handle {
	handle := svc.publicEc256Handle
	return handle
}
